In [1]:
%matplotlib inline
import pylab
import seaborn
import numpy as np
import nengo
In [2]:
class RegularSpiking(object):
def __init__(self, size, dt=0.001):
self.state = np.zeros(size)
self.threshold = 1.0 / dt
self.output = np.zeros(size)
def __call__(self, t, x):
self.output[:] = 0
self.state += x
spikes = np.where(self.state > self.threshold)
self.output[spikes] = self.threshold
self.state -= self.output
return self.output
Here, we just check if there are any events within dt time, and if there are, we assume there is only one of them.
This is a horrible approximation. Don't do this.
In [3]:
class PoissonSpikingApproximate(object):
def __init__(self, size, seed, dt=0.001):
self.rng = np.random.RandomState(seed=seed)
self.dt = dt
self.value = 1.0 / dt
self.size = size
self.output = np.zeros(size)
def __call__(self, t, x):
self.output[:] = 0
p = 1.0 - np.exp(-x*self.dt)
self.output[p>self.rng.rand(self.size)] = self.value
return self.output
For this attempt, I'm screwing something up in the logic, as it's firing more frequently than it should. The idea is to use the same approach as above to see if any spikes happen during the dt. If there are any spikes, then count that as 1 spike. But now I need to know if there are any more spikes (given that we know at least one spike happened). Since one spike happened at one instant in time, but all the other points in time during that dt could also have a spike (since the whole point of a poisson process is that everything's independent), we can remove that one infinitessimal point in take from dt, leaving us with dt, and we can check if any spikes happened in that remaining time. In other words, just do the same process over again, only considering those neurons who did spike on the previous iteration of this logic. Continue until no neurons have another spike.
I'm not sure if my logic is bad or my coding is bad, but this doesn't work.
In [84]:
class PoissonSpikingExactBad(object):
def __init__(self, size, seed, dt=0.001):
self.rng = np.random.RandomState(seed=seed)
self.dt = dt
self.value = 1.0 / dt
self.size = size
self.output = np.zeros(size)
def __call__(self, t, x):
p = 1.0 - np.exp(-x*self.dt)
self.output[:] = 0
s = np.where(p>self.rng.rand(self.size))[0]
self.output[s] += self.value
count = len(s)
while count > 0:
s2 = np.where(p[s]>self.rng.rand(count))[0]
s = s[s2]
self.output[s] += self.value
count = len(s)
return self.output
Now for one that works. Here we do the approach of actually figuring out when during the time step the events happen, and continue until we fall off the end of the timestep. This is how everyone says to do it.
In [85]:
class PoissonSpikingExact(object):
def __init__(self, size, seed, dt=0.001):
self.rng = np.random.RandomState(seed=seed)
self.dt = dt
self.value = 1.0 / dt
self.size = size
self.output = np.zeros(size)
def next_spike_times(self, rate):
return -np.log(1.0-self.rng.rand(len(rate))) / rate
def __call__(self, t, x):
self.output[:] = 0
next_spikes = self.next_spike_times(x)
s = np.where(next_spikes<self.dt)[0]
count = len(s)
self.output[s] += self.value
while count > 0:
next_spikes[s] += self.next_spike_times(x[s])
s2 = np.where(next_spikes[s]<self.dt)[0]
count = len(s2)
s = s[s2]
self.output[s] += self.value
return self.output
In [86]:
model = nengo.Network()
with model:
freq=10
stim = nengo.Node(lambda t: np.sin(t*np.pi*2*freq))
ens = nengo.Ensemble(n_neurons=5, dimensions=1, neuron_type=nengo.LIFRate(), seed=1)
nengo.Connection(stim, ens, synapse=None)
regular_spikes = nengo.Node(RegularSpiking(ens.n_neurons), size_in=ens.n_neurons)
nengo.Connection(ens.neurons, regular_spikes, synapse=None)
poisson_spikes = nengo.Node(PoissonSpikingExact(ens.n_neurons, seed=1), size_in=ens.n_neurons)
nengo.Connection(ens.neurons, poisson_spikes, synapse=None)
p_rate = nengo.Probe(ens.neurons)
p_regular = nengo.Probe(regular_spikes)
p_poisson = nengo.Probe(poisson_spikes)
sim = nengo.Simulator(model)
sim.run(0.1)
In [87]:
pylab.figure(figsize=(10,8))
pylab.subplot(3,1,1)
pylab.plot(sim.trange(), sim.data[p_rate])
pylab.xlim(0, sim.time)
pylab.ylabel('rate')
pylab.subplot(3,1,2)
import nengo.utils.matplotlib
nengo.utils.matplotlib.rasterplot(sim.trange(), sim.data[p_regular])
pylab.xlim(0, sim.time)
pylab.ylabel('regular spiking')
pylab.subplot(3,1,3)
nengo.utils.matplotlib.rasterplot(sim.trange(), sim.data[p_poisson])
pylab.xlim(0, sim.time)
pylab.ylabel('poisson spiking')
pylab.show()
Let's test the accuracy of these models.
In [88]:
def test_accuracy(cls, rate, T=1):
test_model = nengo.Network()
with test_model:
stim = nengo.Node(rate)
spikes = nengo.Node(cls(1, seed=1), size_in=1)
nengo.Connection(stim, spikes, synapse=None)
p = nengo.Probe(spikes)
sim = nengo.Simulator(test_model)
sim.run(T, progress_bar=False)
return np.mean(sim.data[p])
In [89]:
rates = np.linspace(0, 1000, 11)
result_approx = [test_accuracy(PoissonSpikingApproximate, r) for r in rates]
result_bad = [test_accuracy(PoissonSpikingExactBad, r) for r in rates]
result_exact = [test_accuracy(PoissonSpikingExact, r) for r in rates]
In [91]:
pylab.plot(rates, result_approx, label='spike rate approx')
pylab.plot(rates, result_bad, label='spike rate bad')
pylab.plot(rates, result_exact, label='spike rate exact')
pylab.plot(rates, rates, ls='--', c='k', label='ideal')
pylab.legend(loc='best')
pylab.show()
In [ ]: